{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": 1,
            "id": "1ecd677a",
            "metadata": {},
            "outputs": [],
            "source": [
                "# only run this if you have an editable install\n",
                "%load_ext autoreload\n",
                "%autoreload 2"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 2,
            "id": "25157da5",
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "Found cached dataset fiqa (/home/jjmachan/.cache/huggingface/datasets/explodinggradients___fiqa/main/1.0.0/3dc7b639f5b4b16509a3299a2ceb78bf5fe98ee6b5fee25e7d5e4d290c88efb8)\n"
                    ]
                },
                {
                    "data": {
                        "text/plain": [
                            "Dataset({\n",
                            "    features: ['question', 'ground_truths'],\n",
                            "    num_rows: 648\n",
                            "})"
                        ]
                    },
                    "execution_count": 2,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "from datasets import load_dataset\n",
                "\n",
                "# test split of the \"main\" configuration of explodinggradients/fiqa\n",
                "fiqa_test = load_dataset(\n",
                "    \"explodinggradients/fiqa\",\n",
                "    \"main\",\n",
                "    split=\"test\",\n",
                ")\n",
                "fiqa_test"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "08739aad",
            "metadata": {},
            "source": [
                "## k=1, chunk_size=300\n",
                "\n",
                "We know the performance of the baseline model, so let's see if we can improve on it."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 3,
            "id": "2fd6b995",
            "metadata": {},
            "outputs": [],
            "source": [
                "from llama_index import GPTVectorStoreIndex, MockEmbedding\n",
                "from langchain.embeddings.huggingface import HuggingFaceEmbeddings\n",
                "from llama_index import LangchainEmbedding, ServiceContext, StorageContext\n",
                "\n",
                "# service context backed by a HF embedding model (wrapped via langchain)\n",
                "hf_sc = ServiceContext.from_defaults(\n",
                "    embed_model=LangchainEmbedding(HuggingFaceEmbeddings())\n",
                ")\n",
                "\n",
                "# service context backed by mock embeddings\n",
                "embed_model = MockEmbedding(embed_dim=1536)\n",
                "mock = ServiceContext.from_defaults(embed_model=embed_model)\n",
                "\n",
                "# service context with the library defaults (openai embeddings)\n",
                "openai_sc = ServiceContext.from_defaults()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 4,
            "id": "6ad83a2c",
            "metadata": {},
            "outputs": [],
            "source": [
                "# load the index\n",
                "from llama_index import StorageContext, load_index_from_storage, ServiceContext\n",
                "\n",
                "# point a storage context at the index persisted under ./storage\n",
                "storage_context = StorageContext.from_defaults(persist_dir=\"./storage\")\n",
                "\n",
                "# restore the index from that storage context\n",
                "index = load_index_from_storage(storage_context)\n",
                "\n",
                "# query engine built straight from the index, using the openai service context\n",
                "qe = index.as_query_engine(\n",
                "    mode=\"embedding\",\n",
                "    verbose=True,\n",
                "    service_context=openai_sc,\n",
                "    use_async=False,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 5,
            "id": "c149a4be",
            "metadata": {},
            "outputs": [],
            "source": [
                "from llama_index import (\n",
                "    GPTVectorStoreIndex,\n",
                "    ResponseSynthesizer,\n",
                ")\n",
                "from llama_index.retrievers import VectorIndexRetriever\n",
                "from llama_index.query_engine import RetrieverQueryEngine\n",
                "from llama_index.indices.postprocessor import SimilarityPostprocessor\n",
                "\n",
                "# retriever: ask the index for the single most similar node (k=1)\n",
                "retriever = VectorIndexRetriever(index=index, similarity_top_k=1)\n",
                "\n",
                "# response synthesizer with a similarity post-processor (cutoff 0.7)\n",
                "response_synthesizer = ResponseSynthesizer.from_args(\n",
                "    node_postprocessors=[\n",
                "        SimilarityPostprocessor(similarity_cutoff=0.7),\n",
                "    ],\n",
                ")\n",
                "\n",
                "# combine retriever and synthesizer into the query engine bound to `qe`\n",
                "qe = RetrieverQueryEngine(\n",
                "    response_synthesizer=response_synthesizer,\n",
                "    retriever=retriever,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 6,
            "id": "3df5d725",
            "metadata": {},
            "outputs": [],
            "source": [
                "def generate_response(row):\n",
                "    \"\"\"Query the global ``qe`` with ``row[\"question\"]`` and record the result on ``row``.\n",
                "\n",
                "    Sets ``row[\"answer\"]`` to the query result's ``response`` and\n",
                "    ``row[\"contexts\"]`` to the ``node.text`` of every source node, then\n",
                "    returns the same ``row`` object.\n",
                "    \"\"\"\n",
                "    result = qe.query(row[\"question\"])\n",
                "    row[\"answer\"] = result.response\n",
                "\n",
                "    contexts = []\n",
                "    for scored_node in result.source_nodes:\n",
                "        contexts.append(scored_node.node.text)\n",
                "    row[\"contexts\"] = contexts\n",
                "\n",
                "    return row\n",
                "\n",
                "\n",
                "# generate_response(fiqa_test[0])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 9,
            "id": "1d5afb7c",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "application/vnd.jupyter.widget-view+json": {
                            "model_id": "",
                            "version_major": 2,
                            "version_minor": 0
                        },
                        "text/plain": [
                            "Map:   0%|          | 0/30 [00:00<?, ? examples/s]"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                },
                {
                    "data": {
                        "text/plain": [
                            "Dataset({\n",
                            "    features: ['question', 'ground_truths', 'answer', 'contexts'],\n",
                            "    num_rows: 30\n",
                            "})"
                        ]
                    },
                    "execution_count": 9,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# answer the first 30 test questions with generate_response\n",
                "gen_ds = (\n",
                "    fiqa_test\n",
                "    .select(range(30))\n",
                "    .map(generate_response)\n",
                ")\n",
                "gen_ds"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 10,
            "id": "f93d7d23",
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "/home/jjmachan/miniconda3/envs/bench/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
                        "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
                        "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
                        "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
                        "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
                        "  warnings.warn(\n",
                        "100%|█████████████████████████████████████████████████████████████| 2/2 [01:18<00:00, 39.17s/it]\n"
                    ]
                },
                {
                    "data": {
                        "application/vnd.jupyter.widget-view+json": {
                            "model_id": "",
                            "version_major": 2,
                            "version_minor": 0
                        },
                        "text/plain": [
                            "Map:   0%|          | 0/30 [00:00<?, ? examples/s]"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                },
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "100%|█████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.51s/it]\n"
                    ]
                },
                {
                    "data": {
                        "application/vnd.jupyter.widget-view+json": {
                            "model_id": "",
                            "version_major": 2,
                            "version_minor": 0
                        },
                        "text/plain": [
                            "Map:   0%|          | 0/30 [00:00<?, ? examples/s]"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                },
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "100%|█████████████████████████████████████████████████████████████| 1/1 [00:12<00:00, 12.51s/it]\n"
                    ]
                },
                {
                    "data": {
                        "text/plain": [
                            "{'NLI_score': 0.8822222222222222, 'answer_relevancy': 0.8647333333333332, 'context_ relevancy': 0.8236333333333333, 'ragas_score': 0.8561498126750564}"
                        ]
                    },
                    "execution_count": 10,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# evaluate\n",
                "from ragas import evaluate\n",
                "from ragas.metrics import factuality, answer_relevancy, context_relevancy\n",
                "\n",
                "evaluate(\n",
                "    gen_ds,\n",
                "    metrics=[factuality, answer_relevancy, context_relevancy],\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "43b21990",
            "metadata": {},
            "source": [
                "## k=1, chunk_size=100"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 4,
            "id": "7a07ca40",
            "metadata": {},
            "outputs": [],
            "source": [
                "# load the index\n",
                "from llama_index import StorageContext, load_index_from_storage, ServiceContext\n",
                "\n",
                "# point a storage context at the index persisted under ./storage\n",
                "# NOTE(review): this is the same persist_dir the chunk_size=300 section reads --\n",
                "# confirm ./storage was rebuilt with the smaller chunks before running this section.\n",
                "storage_context = StorageContext.from_defaults(persist_dir=\"./storage\")\n",
                "\n",
                "# restore the index from that storage context\n",
                "index = load_index_from_storage(storage_context)\n",
                "\n",
                "# query engine built straight from the index, using the openai service context\n",
                "qe = index.as_query_engine(\n",
                "    mode=\"embedding\",\n",
                "    verbose=True,\n",
                "    service_context=openai_sc,\n",
                "    use_async=False,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 5,
            "id": "a2f5e91e",
            "metadata": {},
            "outputs": [],
            "source": [
                "from llama_index import (\n",
                "    GPTVectorStoreIndex,\n",
                "    ResponseSynthesizer,\n",
                ")\n",
                "from llama_index.retrievers import VectorIndexRetriever\n",
                "from llama_index.query_engine import RetrieverQueryEngine\n",
                "from llama_index.indices.postprocessor import SimilarityPostprocessor\n",
                "\n",
                "# retriever: ask the index for the single most similar node (k=1)\n",
                "retriever = VectorIndexRetriever(index=index, similarity_top_k=1)\n",
                "\n",
                "# response synthesizer with a similarity post-processor (cutoff 0.7)\n",
                "response_synthesizer = ResponseSynthesizer.from_args(\n",
                "    node_postprocessors=[\n",
                "        SimilarityPostprocessor(similarity_cutoff=0.7),\n",
                "    ],\n",
                ")\n",
                "\n",
                "# combine retriever and synthesizer into the query engine bound to `qe`\n",
                "qe = RetrieverQueryEngine(\n",
                "    response_synthesizer=response_synthesizer,\n",
                "    retriever=retriever,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 6,
            "id": "f91db468",
            "metadata": {},
            "outputs": [],
            "source": [
                "def generate_response(row):\n",
                "    \"\"\"Query the global ``qe`` with ``row[\"question\"]`` and record the result on ``row``.\n",
                "\n",
                "    Sets ``row[\"answer\"]`` to the query result's ``response`` and\n",
                "    ``row[\"contexts\"]`` to the ``node.text`` of every source node, then\n",
                "    returns the same ``row`` object.\n",
                "    \"\"\"\n",
                "    result = qe.query(row[\"question\"])\n",
                "    row[\"answer\"] = result.response\n",
                "\n",
                "    contexts = []\n",
                "    for scored_node in result.source_nodes:\n",
                "        contexts.append(scored_node.node.text)\n",
                "    row[\"contexts\"] = contexts\n",
                "\n",
                "    return row\n",
                "\n",
                "\n",
                "# generate_response(fiqa_test[0])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 7,
            "id": "661ad12b",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "application/vnd.jupyter.widget-view+json": {
                            "model_id": "",
                            "version_major": 2,
                            "version_minor": 0
                        },
                        "text/plain": [
                            "Map:   0%|          | 0/30 [00:00<?, ? examples/s]"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                },
                {
                    "data": {
                        "text/plain": [
                            "Dataset({\n",
                            "    features: ['question', 'ground_truths', 'answer', 'contexts'],\n",
                            "    num_rows: 30\n",
                            "})"
                        ]
                    },
                    "execution_count": 7,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# answer the first 30 test questions with generate_response\n",
                "gen_ds = (\n",
                "    fiqa_test\n",
                "    .select(range(30))\n",
                "    .map(generate_response)\n",
                ")\n",
                "gen_ds"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 8,
            "id": "96e08092",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "application/vnd.jupyter.widget-view+json": {
                            "model_id": "46e26286ecbc4a0891f8ee228898ca20",
                            "version_major": 2,
                            "version_minor": 0
                        },
                        "text/plain": [
                            "Downloading model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                },
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "100%|█████████████████████████████████████████████████████████████| 2/2 [00:56<00:00, 28.39s/it]\n"
                    ]
                },
                {
                    "data": {
                        "application/vnd.jupyter.widget-view+json": {
                            "model_id": "",
                            "version_major": 2,
                            "version_minor": 0
                        },
                        "text/plain": [
                            "Map:   0%|          | 0/30 [00:00<?, ? examples/s]"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                },
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "100%|█████████████████████████████████████████████████████████████| 1/1 [00:08<00:00,  8.04s/it]\n"
                    ]
                },
                {
                    "data": {
                        "application/vnd.jupyter.widget-view+json": {
                            "model_id": "",
                            "version_major": 2,
                            "version_minor": 0
                        },
                        "text/plain": [
                            "Map:   0%|          | 0/30 [00:00<?, ? examples/s]"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                },
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "100%|█████████████████████████████████████████████████████████████| 1/1 [00:18<00:00, 18.73s/it]\n"
                    ]
                },
                {
                    "data": {
                        "text/plain": [
                            "{'ragas_score': 0.8386, 'factuality': 0.8289, 'answer_relevancy': 0.8646, 'context_ relevancy': 0.8236}"
                        ]
                    },
                    "execution_count": 8,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# evaluate\n",
                "from ragas import evaluate\n",
                "from ragas.metrics import factuality, answer_relevancy, context_relevancy\n",
                "\n",
                "evaluate(\n",
                "    gen_ds,\n",
                "    metrics=[factuality, answer_relevancy, context_relevancy],\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 14,
            "id": "87054feb",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "application/vnd.jupyter.widget-view+json": {
                            "model_id": "da6babe02adf49369a6d708487eeb068",
                            "version_major": 2,
                            "version_minor": 0
                        },
                        "text/plain": [
                            "Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                },
                {
                    "data": {
                        "text/plain": [
                            "82699"
                        ]
                    },
                    "execution_count": 14,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# save to fiqa hub\n",
                "import os\n",
                "\n",
                "path_to_dataset = \"../../../../datasets/fiqa/\"\n",
                "\n",
                "# write the generated rows as a CSV file inside path_to_dataset\n",
                "gen_ds.to_csv(\n",
                "    os.path.join(path_to_dataset, \"baseline_chunk100_k1.csv\"),\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "8b957401",
            "metadata": {},
            "source": [
                "## Cohere Rerank"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 16,
            "id": "15f4c130",
            "metadata": {},
            "outputs": [],
            "source": [
                "from llama_index.indices.postprocessor.cohere_rerank import CohereRerank\n",
                "import os\n",
                "\n",
                "# the same top_k is used for retrieval (similarity_top_k) and for the reranker (top_n)\n",
                "top_k = 4\n",
                "cohere_rerank = CohereRerank(\n",
                "    api_key=os.environ[\"COHERE_API_KEY\"],  # key comes from the environment\n",
                "    top_n=top_k,\n",
                ")\n",
                "reranking_qe = index.as_query_engine(\n",
                "    node_postprocessors=[cohere_rerank],\n",
                "    similarity_top_k=top_k,\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 17,
            "id": "6a73b189",
            "metadata": {},
            "outputs": [],
            "source": [
                "def generate_response(row):\n",
                "    \"\"\"Query the global ``reranking_qe`` with ``row[\"question\"]`` and record the result.\n",
                "\n",
                "    Sets ``row[\"answer\"]`` to the query result's ``response`` and\n",
                "    ``row[\"contexts\"]`` to the ``node.text`` of every source node, then\n",
                "    returns the same ``row`` object.\n",
                "    \"\"\"\n",
                "    result = reranking_qe.query(row[\"question\"])\n",
                "    row[\"answer\"] = result.response\n",
                "\n",
                "    contexts = []\n",
                "    for scored_node in result.source_nodes:\n",
                "        contexts.append(scored_node.node.text)\n",
                "    row[\"contexts\"] = contexts\n",
                "\n",
                "    return row\n",
                "\n",
                "\n",
                "# generate_response(fiqa_test[0])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 18,
            "id": "32bd4281",
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "Parameter 'function'=<function generate_response at 0x7f832ef4f010> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n"
                    ]
                },
                {
                    "data": {
                        "application/vnd.jupyter.widget-view+json": {
                            "model_id": "",
                            "version_major": 2,
                            "version_minor": 0
                        },
                        "text/plain": [
                            "Map:   0%|          | 0/30 [00:00<?, ? examples/s]"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                },
                {
                    "data": {
                        "text/plain": [
                            "Dataset({\n",
                            "    features: ['question', 'ground_truths', 'answer', 'contexts'],\n",
                            "    num_rows: 30\n",
                            "})"
                        ]
                    },
                    "execution_count": 18,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# answer the first 30 test questions with generate_response\n",
                "gen_ds = (\n",
                "    fiqa_test\n",
                "    .select(range(30))\n",
                "    .map(generate_response)\n",
                ")\n",
                "gen_ds"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 24,
            "id": "6747689f",
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "100%|█████████████████████████████████████████████████████████████| 2/2 [01:42<00:00, 51.01s/it]\n"
                    ]
                },
                {
                    "data": {
                        "application/vnd.jupyter.widget-view+json": {
                            "model_id": "",
                            "version_major": 2,
                            "version_minor": 0
                        },
                        "text/plain": [
                            "Map:   0%|          | 0/30 [00:00<?, ? examples/s]"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                },
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "100%|█████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.76s/it]\n"
                    ]
                },
                {
                    "data": {
                        "application/vnd.jupyter.widget-view+json": {
                            "model_id": "",
                            "version_major": 2,
                            "version_minor": 0
                        },
                        "text/plain": [
                            "Map:   0%|          | 0/30 [00:00<?, ? examples/s]"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                },
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "100%|█████████████████████████████████████████████████████████████| 4/4 [01:00<00:00, 15.05s/it]\n"
                    ]
                },
                {
                    "data": {
                        "text/plain": [
                            "{'NLI_score': 0.8805555555555556, 'answer_relevancy': 0.8766333333333333, 'context_ relevancy': 0.8163833333333333, 'ragas_score': 0.8568272415550885}"
                        ]
                    },
                    "execution_count": 24,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# evaluate\n",
                "from ragas import evaluate\n",
                "from ragas.metrics import factuality, answer_relevancy, context_relevancy\n",
                "\n",
                "evaluate(\n",
                "    gen_ds,\n",
                "    metrics=[factuality, answer_relevancy, context_relevancy],\n",
                ")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 3,
            "id": "4301895f",
            "metadata": {},
            "outputs": [
                {
                    "ename": "NameError",
                    "evalue": "name 'gen_ds' is not defined",
                    "output_type": "error",
                    "traceback": [
                        "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
                        "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
                        "Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m evals[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcohere_reranked\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mgen_ds\u001b[49m\n\u001b[1;32m      2\u001b[0m evals\n",
                        "\u001b[0;31mNameError\u001b[0m: name 'gen_ds' is not defined"
                    ]
                }
            ],
            "source": [
                "# NOTE(review): `evals` is never assigned in the visible notebook, and the saved\n",
                "# output of this cell is a NameError for `gen_ds` -- confirm where `evals` is\n",
                "# meant to come from and that the cells above ran in the same kernel.\n",
                "evals[\"cohere_reranked\"] = gen_ds\n",
                "evals"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "02cb461c",
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "Python 3 (ipykernel)",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.10.12"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 5
}
